import argparse
import datetime
import hashlib
import importlib
import json
import os
from pathlib import Path

import numpy as np
import shrinker as shrinker_module
import yaml
from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs

from lmms_eval.utils import simple_parse_args_string

# Registry of supported shrinkers: maps the value accepted by ``--shrinker``
# to the name of the class that is looked up on the ``shrinker`` module.
AVAILABEL_SHRINKER = {
    "embed": "Embed_Shrinker",
}


def parse_arguments() -> argparse.Namespace:
    """Parse the command-line options of the dataset shrinking script.

    Options:
        --shrinker: Key identifying the shrinker to use (required).
        --num_items: Comma-separated item count(s) for the shrinked dataset,
            kept as a raw string for the caller to split (required).
        --tasks: Comma-separated task names, kept as a raw string (required).
        --push_to_hub: Flag; ``True`` when present, ``False`` otherwise.
        --shrinker_kwargs: ``arg=xxx,arg2=xxx`` formatted string; defaults to
            an empty string when omitted.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an unknown option is supplied.
    """
    parser = argparse.ArgumentParser()
    # These three options are dereferenced unconditionally by the entry point,
    # so let argparse report a clear usage error when one is missing instead
    # of failing later on a ``None`` value.
    parser.add_argument("--shrinker", type=str, required=True, help="The type of shrinker you want to use")
    parser.add_argument("--num_items", type=str, required=True, help="The number of items you want in your shrinked dataset")
    parser.add_argument("--tasks", type=str, required=True, help="The task you want to shrink. Separate each task with comma, will be parsed in to list")
    parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push the shrinked dataset to hub")
    # Default to an empty string so downstream string parsing never receives None.
    parser.add_argument("--shrinker_kwargs", type=str, default="", help="In args=xxx,args2=xxx format. Will be parsed into dict")

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: build the selected shrinker for every requested task
    # and run it, one task at a time.
    args = parse_arguments()
    # ``--shrinker_kwargs`` may be omitted, in which case argparse can yield
    # None; fall back to an empty string so the parser always gets a str.
    shrinker_kwargs = simple_parse_args_string(args.shrinker_kwargs or "")
    shrinker_name = args.shrinker
    push_to_hub = args.push_to_hub
    # Tolerate whitespace around the commas, e.g. "task_a, task_b".
    tasks = [task.strip() for task in args.tasks.split(",")]
    raw_num_items = args.num_items.split(",")

    # Validate with explicit exceptions: ``assert`` statements are removed
    # when Python runs with -O, which would silently skip these checks.
    if len(raw_num_items) != 1 and len(raw_num_items) != len(tasks):
        raise ValueError("Either provide one num items for all task or one num item for each task")
    if shrinker_name not in AVAILABEL_SHRINKER:
        raise ValueError(f"Unavailable shrinker {shrinker_name}. You can choose from {AVAILABEL_SHRINKER.keys()}")

    num_items = [float(n) for n in raw_num_items]
    if len(num_items) == 1:
        # A single value applies to every task.
        num_items = num_items * len(tasks)

    # Configure the process group with a 60000-second timeout.
    kwargs_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=60000))
    accelerator = Accelerator(kwargs_handlers=[kwargs_handler])

    # The shrinker class does not depend on the task, so resolve it once.
    shrinker_cls = getattr(shrinker_module, AVAILABEL_SHRINKER[shrinker_name])
    for task, task_num_items in zip(tasks, num_items):
        shrinker = shrinker_cls(task=task, num_items=task_num_items, push_to_hub=push_to_hub, name=shrinker_name, **shrinker_kwargs)
        shrinker.shrink()
        # Synchronise all processes before moving on to the next task.
        accelerator.wait_for_everyone()
